{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Optional, Tuple, Union\n",
    "from torch.nn import CrossEntropyLoss\n",
    "from torch import nn\n",
    "import torch\n",
    "\n",
    "class BartClassificationHead(nn.Module):\n",
    "    \"\"\"Classification head: dropout, dense layer, tanh, dropout, output projection.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        input_dim: int,\n",
    "        inner_dim: int,\n",
    "        num_classes: int,\n",
    "        pooler_dropout: float,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Create the layers of the head.\n",
    "\n",
    "        Args:\n",
    "            input_dim: size of the last dimension of the incoming hidden states.\n",
    "            inner_dim: width of the intermediate dense layer.\n",
    "            num_classes: number of output logits.\n",
    "            pooler_dropout: dropout probability shared by both dropout applications.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        # Attribute names and creation order are part of the contract: they\n",
    "        # define the state-dict keys and the order in which the RNG is consumed.\n",
    "        self.dense = nn.Linear(in_features=input_dim, out_features=inner_dim)\n",
    "        self.dropout = nn.Dropout(p=pooler_dropout)\n",
    "        self.out_proj = nn.Linear(in_features=inner_dim, out_features=num_classes)\n",
    "\n",
    "    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Map hidden states to class logits; only the last dimension changes size.\"\"\"\n",
    "        projected = self.dense(self.dropout(hidden_states))\n",
    "        activated = self.dropout(torch.tanh(projected))\n",
    "        return self.out_proj(activated)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config\n",
    "\n",
    "tokenizer = T5Tokenizer.from_pretrained('mesolitica/t5-small-standard-bahasa-cased')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class T5Tagging(T5ForConditionalGeneration):\n",
    "    \"\"\"\n",
    "    T5 sequence-to-sequence model with an additional tagging head.\n",
    "\n",
    "    The head (a BartClassificationHead) reads the last decoder hidden state and\n",
    "    produces config.num_labels logits per decoder position.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config: T5Config, **kwargs):\n",
    "        \"\"\"Build the base T5 model and attach the tagging head sized from config.\"\"\"\n",
    "        super().__init__(config, **kwargs)\n",
    "        self.classification_head = BartClassificationHead(\n",
    "            config.d_model,\n",
    "            config.d_model,\n",
    "            config.num_labels,\n",
    "            config.dropout_rate,\n",
    "        )\n",
    "        # NOTE(review): whether the T5 _init_weights hook actually re-initialises\n",
    "        # plain nn.Linear modules depends on the transformers version - confirm.\n",
    "        self._init_weights(self.classification_head.dense)\n",
    "        self._init_weights(self.classification_head.out_proj)\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        input_ids: Optional[torch.LongTensor] = None,\n",
    "        attention_mask: Optional[torch.FloatTensor] = None,\n",
    "        decoder_input_ids: Optional[torch.LongTensor] = None,\n",
    "        decoder_attention_mask: Optional[torch.BoolTensor] = None,\n",
    "        head_mask: Optional[torch.FloatTensor] = None,\n",
    "        decoder_head_mask: Optional[torch.FloatTensor] = None,\n",
    "        cross_attn_head_mask: Optional[torch.Tensor] = None,\n",
    "        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n",
    "        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n",
    "        inputs_embeds: Optional[torch.FloatTensor] = None,\n",
    "        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,\n",
    "        labels: Optional[torch.LongTensor] = None,\n",
    "        labels_tag = None,\n",
    "        use_cache: Optional[bool] = None,\n",
    "        output_attentions: Optional[bool] = None,\n",
    "        output_hidden_states: Optional[bool] = None,\n",
    "        return_dict: Optional[bool] = None,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Run the seq2seq forward pass and, when labels_tag is given, add the tagging loss.\n",
    "\n",
    "        The parent forward is always called with output_hidden_states=True and\n",
    "        return_dict=True because the tagging head needs the decoder hidden\n",
    "        states; the incoming output_hidden_states / return_dict arguments are\n",
    "        accepted for signature compatibility but are not forwarded.\n",
    "\n",
    "        Args:\n",
    "            labels_tag: optional tensor of tag ids. It is flattened with view(-1)\n",
    "                and scored against the per-position tag logits; presumably one id\n",
    "                per decoder position - confirm against the training collator.\n",
    "            All other arguments are forwarded unchanged to\n",
    "            T5ForConditionalGeneration.forward.\n",
    "\n",
    "        Returns:\n",
    "            The parent output object. Its loss is the LM loss plus the tagging\n",
    "            loss when both labels and labels_tag are given, the tagging loss\n",
    "            alone when only labels_tag is given, and untouched otherwise.\n",
    "        \"\"\"\n",
    "        outputs = super().forward(\n",
    "            input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            decoder_input_ids=decoder_input_ids,\n",
    "            decoder_attention_mask=decoder_attention_mask,\n",
    "            head_mask=head_mask,\n",
    "            decoder_head_mask=decoder_head_mask,\n",
    "            cross_attn_head_mask=cross_attn_head_mask,\n",
    "            encoder_outputs=encoder_outputs,\n",
    "            past_key_values=past_key_values,\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            decoder_inputs_embeds=decoder_inputs_embeds,\n",
    "            labels=labels,\n",
    "            use_cache=use_cache,\n",
    "            output_attentions=output_attentions,\n",
    "            output_hidden_states=True,\n",
    "            return_dict=True,\n",
    "        )\n",
    "\n",
    "        # The tag logits are only needed for the loss, so skip the head entirely\n",
    "        # when no tag labels are supplied (e.g. on every step of generate()).\n",
    "        if labels_tag is not None:\n",
    "            last_layer = outputs.decoder_hidden_states[-1]\n",
    "            logits = self.classification_head(last_layer)\n",
    "            loss_fct = CrossEntropyLoss()\n",
    "            tag_loss = loss_fct(logits.view(-1, self.config.num_labels), labels_tag.view(-1))\n",
    "            # Without labels the parent returns loss=None; adding to None would\n",
    "            # raise TypeError, so fall back to the tagging loss alone. Item\n",
    "            # assignment keeps the key view and the attribute view in sync.\n",
    "            if outputs.loss is None:\n",
    "                outputs['loss'] = tag_loss\n",
    "            else:\n",
    "                outputs['loss'] = outputs.loss + tag_loss\n",
    "\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['finetune-t5-small-standard-bahasa-cased/checkpoint-290000',\n",
       " 'finetune-t5-small-standard-bahasa-cased/checkpoint-300000',\n",
       " 'finetune-t5-small-standard-bahasa-cased/checkpoint-310000']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from glob import glob\n",
    "\n",
    "checkpoints = sorted(glob('finetune-t5-small-standard-bahasa-cased/checkpoint-*'))\n",
    "checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = T5Tagging.from_pretrained(checkpoints[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Load the evaluation set: one JSON document per line of test.jsonl.\n",
    "data = []\n",
    "# Read as UTF-8 explicitly so the result does not depend on the platform locale.\n",
    "with open('test.jsonl', encoding='utf-8') as fopen:\n",
    "    for line in fopen:\n",
    "        # Skip blank lines (such as a trailing newline) instead of failing in json.loads.\n",
    "        if line.strip():\n",
    "            data.append(json.loads(line))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4994"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Cognocoli-Monticchi ialah komun bak jabatan Corse-du-Sud di kepulauan Corsica ; Perancis ?'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = ' '.join([d[0] for d in data[1][1]])\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_ids = [{'input_ids': tokenizer.encode(f'kesalahan tatabahasa:{s}', return_tensors='pt')[\n",
    "    0]} for s in [x]]\n",
    "padded = tokenizer.pad(input_ids, padding='longest')\n",
    "outputs = model.generate(**padded,do_sample=True, \n",
    "    max_length=50, \n",
    "    top_k=0,num_return_sequences=3,\n",
    "              output_attentions = True,\n",
    "                        output_hidden_states = True,\n",
    "                        output_scores = True,\n",
    "                        return_dict_in_generate = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "odict_keys(['sequences', 'scores', 'encoder_attentions', 'encoder_hidden_states', 'decoder_attentions', 'cross_attentions', 'decoder_hidden_states'])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Cognocoli-Monticchi ialah komun di jabatan Corse-du-Sud di kepulauan Corsica, Perancis.</s>']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.batch_decode(outputs.sequences[:,1:], )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 31])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs.sequences.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "30"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(outputs.decoder_hidden_states)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 30, 512])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "last_layer = torch.stack([o[-1] for o in outputs.decoder_hidden_states])[:,:,0]\n",
    "last_layer = last_layer.transpose(0, 1)\n",
    "last_layer.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3, 30)"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = model.classification_head(last_layer).detach().numpy().argmax(axis = -1)\n",
    "t.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "from malaya.text.bpe import merge_sentencepiece_tokens_tagging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(outputs.sequences)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "s = outputs.sequences[:,1:][0].detach().numpy()\n",
    "s = tokenizer.convert_ids_to_tokens(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(30, 3)"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(s), len(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from malaya.text.bpe import SPECIAL_TOKENS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['▁saya', '▁suka']"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.tokenize('saya suka')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def merge_sentencepiece_tokens_tagging(x, y, model='bert', rejected = None, **kwargs):\n",
    "    \"\"\"\n",
    "    Merge SentencePiece pieces back into words, keeping one label per word.\n",
    "\n",
    "    A piece that does not start with the '▁' word-boundary marker and is not in\n",
    "    `rejected` is glued onto the previous entry; the merged word keeps the label\n",
    "    of its first piece. Entries equal to a rejected token are dropped from the\n",
    "    result.\n",
    "\n",
    "    Note: this definition shadows the function of the same name imported from\n",
    "    malaya.text.bpe earlier in the notebook.\n",
    "\n",
    "    Args:\n",
    "        x: sequence of pieces (str or UTF-8 bytes).\n",
    "        y: sequence of labels, indexable in parallel with x.\n",
    "        model: key into SPECIAL_TOKENS, used only when `rejected` is None.\n",
    "        rejected: tokens that are never merged and are removed from the output.\n",
    "        **kwargs: accepted and ignored.\n",
    "\n",
    "    Returns:\n",
    "        (words, labels): two parallel lists.\n",
    "    \"\"\"\n",
    "    if rejected is None:\n",
    "        rejected = list(SPECIAL_TOKENS[model].values())\n",
    "\n",
    "    def _as_text(token):\n",
    "        # Pieces may arrive as bytes; normalise every fetch, not only the first.\n",
    "        return token.decode() if isinstance(token, bytes) else token\n",
    "\n",
    "    new_paired_tokens = []\n",
    "    n_tokens = len(x)\n",
    "    i = 0\n",
    "\n",
    "    while i < n_tokens:\n",
    "        current_token, current_label = _as_text(x[i]), y[i]\n",
    "        is_continuation = (\n",
    "            not current_token.startswith('▁') and current_token not in rejected\n",
    "        )\n",
    "\n",
    "        if is_continuation and i > 0:\n",
    "            merged_token, merged_label = new_paired_tokens.pop()\n",
    "            # Consume the whole run of continuation pieces. The bounds check\n",
    "            # matters when the run reaches the end of the input (no trailing\n",
    "            # special token), which previously raised IndexError.\n",
    "            while i < n_tokens:\n",
    "                current_token = _as_text(x[i])\n",
    "                if current_token.startswith('▁') or current_token in rejected:\n",
    "                    break\n",
    "                merged_token = merged_token + current_token.replace('▁', '')\n",
    "                i = i + 1\n",
    "            new_paired_tokens.append((merged_token, merged_label))\n",
    "        else:\n",
    "            new_paired_tokens.append((current_token, current_label))\n",
    "            i = i + 1\n",
    "\n",
    "    kept = [pair for pair in new_paired_tokens if pair[0] not in rejected]\n",
    "    words = [token.replace('▁', '') for token, _ in kept]\n",
    "    labels = [label for _, label in kept]\n",
    "    return words, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('Cognocoli-Monticchi', 2),\n",
       " ('ialah', 2),\n",
       " ('komun', 2),\n",
       " ('di', 9),\n",
       " ('jabatan', 2),\n",
       " ('Corse-du-Sud', 2),\n",
       " ('di', 2),\n",
       " ('kepulauan', 2),\n",
       " ('Corsica', 2),\n",
       " (',', 14),\n",
       " ('Perancis', 2),\n",
       " ('.', 14)]"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "merged = merge_sentencepiece_tokens_tagging(\n",
    "    s, t[0], rejected = tokenizer.all_special_tokens\n",
    ")\n",
    "list(zip(merged[0], merged[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Cloning https://huggingface.co/mesolitica/finetune-tatabahasa-t5-small-standard-bahasa-cased into local empty directory.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a1fa10a113d04cf796600c47837164df",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload file pytorch_model.bin:   0%|          | 4.00k/232M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "remote: Scanning LFS files for validity, may be slow...        \n",
      "remote: LFS file scan complete.        \n",
      "To https://huggingface.co/mesolitica/finetune-tatabahasa-t5-small-standard-bahasa-cased\n",
      "   7ef1414..e7ed1a3  main -> main\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'https://huggingface.co/mesolitica/finetune-tatabahasa-t5-small-standard-bahasa-cased/commit/e7ed1a3d1fe6636cadc2ef36807a5fbeb1690140'"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.push_to_hub('finetune-tatabahasa-t5-small-standard-bahasa-cased', organization='mesolitica')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7473406aa87245f9beb58490f5f91db1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload file spiece.model:   1%|          | 4.00k/784k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "remote: Scanning LFS files for validity, may be slow...        \n",
      "remote: LFS file scan complete.        \n",
      "To https://huggingface.co/mesolitica/finetune-tatabahasa-t5-small-standard-bahasa-cased\n",
      "   e7ed1a3..49122ef  main -> main\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'https://huggingface.co/mesolitica/finetune-tatabahasa-t5-small-standard-bahasa-cased/commit/49122ef664b14ab37f70b8b690b759e127689677'"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.push_to_hub('finetune-tatabahasa-t5-small-standard-bahasa-cased', organization='mesolitica')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "\n",
    "def compute_exact(a_gold, a_pred):\n",
    "    \"\"\"Return 1 when the gold and predicted sequences compare equal, else 0.\"\"\"\n",
    "    is_match = a_gold == a_pred\n",
    "    return int(is_match)\n",
    "\n",
    "\n",
    "def compute_f1(a_gold, a_pred):\n",
    "    \"\"\"\n",
    "    Order-insensitive (bag of items) F1 between two sequences.\n",
    "\n",
    "    Returns an int (0 or 1) when either sequence is empty or nothing overlaps,\n",
    "    otherwise the harmonic mean of precision and recall as a float.\n",
    "    \"\"\"\n",
    "    overlap = collections.Counter(a_gold) & collections.Counter(a_pred)\n",
    "    shared = sum(overlap.values())\n",
    "    n_gold, n_pred = len(a_gold), len(a_pred)\n",
    "\n",
    "    # An empty side means no-answer: score 1 only if both sides agree.\n",
    "    if n_gold == 0 or n_pred == 0:\n",
    "        return int(a_gold == a_pred)\n",
    "    if shared == 0:\n",
    "        return 0\n",
    "\n",
    "    precision = 1.0 * shared / n_pred\n",
    "    recall = 1.0 * shared / n_gold\n",
    "    return (2 * precision * recall) / (precision + recall)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 4994/4994 [13:30<00:00,  6.16it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Per-sample scores: exact match / F1 on the generated words and on the predicted tags.\n",
    "exact_match, f1, exact_match_tag, f1_tag = [], [], [], []\n",
    "for sample in tqdm(data):\n",
    "    # sample[1] holds (token, tag) pairs: the tokens form the model input and the\n",
    "    # tags are the gold tags; sample[0] holds the gold output tokens.\n",
    "    x = ' '.join([d[0] for d in sample[1]])\n",
    "    y = [d[0] for d in sample[0]]\n",
    "    tag = [d[1] for d in sample[1]]\n",
    "\n",
    "    input_ids = [{'input_ids': tokenizer.encode(f'kesalahan tatabahasa:{x}', return_tensors='pt')[0]}]\n",
    "    padded = tokenizer.pad(input_ids, padding='longest')\n",
    "    # Only the sequences and the decoder hidden states are read below, so\n",
    "    # attentions and scores are no longer requested from generate().\n",
    "    outputs = model.generate(\n",
    "        **padded,\n",
    "        max_length=256,\n",
    "        output_hidden_states=True,\n",
    "        return_dict_in_generate=True,\n",
    "    )\n",
    "\n",
    "    # One entry per generation step; o[-1] is the last decoder layer of that step.\n",
    "    # Index 0 on the third axis assumes each step carries a single position\n",
    "    # (cached decoding) - TODO confirm. After the transpose the step axis comes second.\n",
    "    last_layer = torch.stack([o[-1] for o in outputs.decoder_hidden_states])[:,:,0]\n",
    "    last_layer = last_layer.transpose(0, 1)\n",
    "    # Inference only: no autograd graph is needed for the tagging head.\n",
    "    with torch.no_grad():\n",
    "        t = model.classification_head(last_layer).numpy().argmax(axis = -1)\n",
    "\n",
    "    # Drop the first generated id so tokens line up with the per-step tags.\n",
    "    s = outputs.sequences[:,1:][0].detach().numpy()\n",
    "    s = tokenizer.convert_ids_to_tokens(s)\n",
    "    merged = merge_sentencepiece_tokens_tagging(\n",
    "        s, t[0], rejected = tokenizer.all_special_tokens\n",
    "    )\n",
    "\n",
    "    exact_match.append(compute_exact(y, merged[0]))\n",
    "    exact_match_tag.append(compute_exact(tag, merged[1]))\n",
    "    f1.append(compute_f1(y, merged[0]))\n",
    "    f1_tag.append(compute_f1(tag, merged[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.8145774929915899, 0.89447336804165)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(exact_match), np.mean(exact_match_tag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.9781278051739509, 0.9905973774727378)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(f1), np.mean(f1_tag)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
